package com.github.hgkmail.hello.leetcode101.sort;

import com.github.hgkmail.hello.leetcode101.base.CommonUtil;
import com.github.hgkmail.hello.leetcode101.base.ListNode;

//归并排序需要记忆，divide + conquer 的思想还是很厉害的，解题很常见
//数组归并排序: mergeSort, mergeSort2
//链表归并排序: mergeSortLinkedList
public class MergeSort {
    /**
     * 数组归并排序，左闭右开
     * @param nums 待排序数组
     * @param l 左边界（inclusive）
     * @param r 右边界（exclusive）
     * @param temp 临时数组，用于存放合并结果，避免排序过程[频繁分配内存]
     */
    public void mergeSort(int[] nums, int l, int r, int[] temp) {
        if (l+1>=r) {
            return;
        }
        //divide（代码框架）：求mid，然后分成2部分。
        int m=l+(r-l)/2;
        mergeSort(nums, l, m, temp);
        mergeSort(nums, m, r, temp);
        //conquer（代码框架）：3指针操作合并，先存在temp数组，注意循环的写法。
        int i=l; //pointer1: [l, m)
        int j=m; //pointer2: [m, r)
        int k=l; //pointer3: [l, r)
        //①+②这种写法可以统一合并剩余部分，简化代码！
        while (i < m || j < r) { //①只要有一个指针还没走完，就继续循环
            if (j>=r || (i<m && nums[i]<=nums[j])) { //②当j已走完或者i指向的元素比较小，取i的元素（条件的顺序不能颠倒）
                temp[k++]=nums[i++];
            } else {
                temp[k++]=nums[j++];
            }
        }
        //拷贝到原数组
        for (int n = l; n < r; n++) {
            nums[n] = temp[n];
        }
    }

    public void mergeSort2(int[] nums, int l, int r, int[] temp) {
        //end case
        if (l+1>=r) {
            return;
        }
        //divide（代码框架）：求mid，然后分成2部分。
        int m=l+(r-l)/2;
        mergeSort(nums, l, m, temp);
        mergeSort(nums, m, r, temp);
        //conquer（代码框架）：3指针操作合并，先存在temp数组，注意循环的写法。
        int i=l; //pointer1: [l, m)
        int j=m; //pointer2: [m, r)
        int k=l; //pointer3: [l, r)
        //3个循环的写法（这里用后置++的写法可以简化代码）
        while (i < m && j < r) {
            if (nums[i]<=nums[j]) {
                temp[k++]=nums[i++];
            } else {
                temp[k++]=nums[j++];
            }
        }
        while (i < m) temp[k++]=nums[i++];
        while (j < r) temp[k++]=nums[j++];
        //拷贝到原数组
        for (int n = l; n < r; n++) {
            nums[n] = temp[n];
        }
    }

    //无序链表归并排序
    public ListNode mergeSortLinkedList(ListNode head) {
        //end case
        if (head==null || head.next==null) {
            return head;
        }
        //divide（快慢指针找中点）
        ListNode fast=head;
        ListNode slow=head;
        if (fast.next!=null && fast.next.next!=null) {
            fast=fast.next.next;
            slow=slow.next;
        }
        ListNode mid = slow.next;
        slow.next=null;
        head = mergeSortLinkedList(head);
        mid = mergeSortLinkedList(mid);
        //conquer（合并2个有序链表）
        ListNode dummyHead = new ListNode(0, null);
        ListNode i=head, j=mid, k=dummyHead;
        while (i!=null && j!=null) {
            if (i.val<=j.val) {
                k.next=i;
                i=i.next;
            } else {
                k.next=j;
                j=j.next;
            }
            k=k.next;
        }
        k.next = i!=null ? i : j; //别忘了剩余部分
        return dummyHead.next;
    }

    public static void main(String[] args) {
        int[] nums = new int[]{1,5,3,10,-1,100,0,8};
        int[] temp = new int[nums.length];
//        new MergeSort().mergeSort(nums, 0, nums.length, temp);
//        new MergeSort().mergeSort2(nums, 0, nums.length, temp);
//        System.out.println(Arrays.toString(nums));

        ListNode a=new ListNode(3, null);
        ListNode b=new ListNode(1, a);
        ListNode c=new ListNode(2, b);
        ListNode d=new ListNode(4, c);
        CommonUtil.printLinkedList(new MergeSort().mergeSortLinkedList(d));
    }
}
